
import json
import datasets
from fire import Fire
from functools import partial
from typing import List
from loguru import logger
import os
import openai
import numpy as np
import torch
from utils import (
    generate_together,
    generate_openai,
    generate_with_references,
    DEBUG,
)

# Label alphabets for the candidates shown to a listwise judge, keyed by the
# `prompt_identifier` option. "ABC_RAND1" is a fixed permutation of A-Z.
# NOTE: "123" has only nine labels (no "0"), so it cannot label more than
# nine candidates.
IDENTIFIERS = {
    "ABC": list("ABCDEFGHIJKLMNOPQRSTUVWXYZ"),
    "123": list("123456789"),
    "ABC_RAND1": list("YPERCAXUDNOZVTLMSBKFIQGHJW"),
}
def create_evaluation_prompt(instruction, references,prompt_identifier, prompt_template):
    """Fill a judge prompt template and wrap it as a single user chat message.

    Placeholders are substituted sequentially, in this order:
    ``{instruction}``; then ``{identifier_1}``, ``{identifier_2}``, ... from
    *prompt_identifier*; then ``{output_1}``, ``{output_2}``, ... from the
    second element of each ``(model, text)`` pair in *references*. Because the
    substitution is sequential, text inserted earlier can itself be rewritten
    by a later placeholder.

    Returns:
        ``[{'role': 'user', 'content': <filled template>}]``
    """
    filled = prompt_template.replace('{instruction}', instruction)
    substitutions = [
        (f"{{identifier_{position}}}", label)
        for position, label in enumerate(prompt_identifier, start=1)
    ]
    substitutions += [
        (f"{{output_{position}}}", text)
        for position, (_model, text) in enumerate(references, start=1)
    ]
    for placeholder, value in substitutions:
        filled = filled.replace(placeholder, value)
    return [{'role': 'user', 'content': filled}]

def pairwise_process_judgement(judgment):
    """Extract the verdict from a pairwise judge's free-text output.

    Returns "A" when the text contains the marker "[[A]]", otherwise "B" when
    it contains "[[B]]", otherwise None. "[[A]]" wins if both markers appear.
    """
    for verdict in ("A", "B"):
        if f"[[{verdict}]]" in judgment:
            return verdict
    return None

# Separator between the system part and the user part of a Qwen_pairwise prompt file.
_PAIRWISE_SEPARATOR = "#######################"


def _final_responses(messages_list):
    """Return ``(key, text)`` pairs: each candidate's key and its last message's content."""
    return [(key, conversation[-1]["content"]) for key, conversation in messages_list]


def _instruction_of(item, messages_list):
    """Return the row's 'instruction', falling back to the first candidate's first message."""
    instruction = item.get("instruction")
    if instruction is None:
        instruction = messages_list[0][1][0]["content"]
    return instruction


def _qwen_listwise_pick(instruction, references_list, model, prompt, prompt_identifier, n, max_tokens):
    """Ask an LLM judge to name the best candidate in a single call; return its index.

    Candidates are labelled with ``IDENTIFIERS[prompt_identifier]`` and only the
    first element of the returned ``tokens`` is compared against those labels.
    Falls back to index 0 (with an error log) when nothing matches.
    """
    identifiers = IDENTIFIERS[prompt_identifier][:len(references_list)]
    if len(identifiers) < len(references_list):
        logger.warning(
            f"identifier set {prompt_identifier!r} has only {len(identifiers)} labels "
            f"for {len(references_list)} candidates"
        )
    prompt_messages = create_evaluation_prompt(instruction, references_list, identifiers, prompt)
    if DEBUG:
        logger.debug(f"messages: {prompt_messages}")

    if "gpt" in model:
        response = generate_openai(
            model=model,
            messages=prompt_messages,
            n=n,
            logprobs=True,
            max_tokens=max_tokens,
            temperature=0.0,
        )
        first = response["logprobs"].content[0]
        logprobs = {"tokens": first.token, "token_logprobs": first.logprob}
    else:
        response = generate_together(
            model=model,
            messages=prompt_messages,
            n=n,
            logprobs=1,
            max_tokens=max_tokens,
        )
        logprobs = response["logprobs"]
    logger.info(f"logprobs: {logprobs}")

    picked_token = logprobs["tokens"]
    logger.info(f"picked_token: {picked_token}")
    # Strip surrounding whitespace so a token such as " A" still matches label "A".
    first_token = str(picked_token[0]).strip().lower()
    picked_idx = 0
    for i, label in enumerate(identifiers):
        if label.lower() == first_token:
            picked_idx = i
            break
    else:
        logger.error("Token not found in prompt_identifier use default 0")
    if DEBUG:
        logger.debug(f"picked_idx: {picked_idx}")
        logger.debug(f"picked_token: {identifiers}")
    return picked_idx


def _qwen_pairwise_scores(instruction, references_list, model, prompt, n):
    """Round-robin pairwise LLM judging; return one score per candidate.

    Every unordered pair is judged twice with the presentation order swapped
    (to offset position bias). Each win adds 0.5 to the winner's score.
    """
    system_prompt, separator, user_prompt = prompt.partition(_PAIRWISE_SEPARATOR)
    if not separator:
        raise ValueError(f"Qwen_pairwise prompt must contain the separator {_PAIRWISE_SEPARATOR!r}")
    system_prompt, user_prompt = system_prompt.strip(), user_prompt.strip()

    def judge(reference_pair):
        user_message = create_evaluation_prompt(instruction, reference_pair, [], user_prompt)
        chat = [{'role': 'system', 'content': system_prompt}] + user_message
        generate = generate_openai if "gpt" in model else generate_together
        response = generate(model=model, messages=chat, n=n, max_tokens=512, temperature=0.0)
        return pairwise_process_judgement(response["outputs"][0])

    scores = [0] * len(references_list)
    for i in range(len(references_list)):
        for j in range(i + 1, len(references_list)):
            pair = [references_list[i], references_list[j]]
            verdict = judge(pair)
            if verdict == "A":
                scores[i] += 0.5
            elif verdict == "B":
                scores[j] += 0.5
            # Same pair, swapped order: "A" now refers to candidate j.
            verdict = judge(pair[::-1])
            if verdict == "A":
                scores[j] += 0.5
            elif verdict == "B":
                scores[i] += 0.5
    return scores


def _pairrm_rank(messages_list, ranker):
    """Rank candidates with llm-blender PairRM; return (best_idx, worst_idx).

    The prompt is taken from the first message of the first candidate. A lower
    rank is better.
    """
    # https://huggingface.co/llm-blender/PairRM
    candidates_texts = [conversation[-1]["content"] for _, conversation in messages_list]
    instruction = messages_list[0][1][0]['content']
    ranks = ranker.rank([instruction], [candidates_texts], return_scores=False, batch_size=1)
    logger.info(f"ranks: {ranks}")
    ranks = ranks[0].tolist()
    picked_idx = ranks.index(min(ranks))
    reject_idx = ranks.index(max(ranks))
    logger.info(f"picked_idx: {picked_idx}, reject_idx: {reject_idx}")
    return picked_idx, reject_idx


def _armorm_scores(messages_list, ranker, tokenizer):
    """Score each full conversation with ArmoRM's preference score (``output.score``)."""
    # https://huggingface.co/RLHFlow/ArmoRM-Llama3-8B-v0.1
    scores = []
    for _, conversation in messages_list:
        input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to("cuda")
        with torch.no_grad():
            output = ranker(input_ids)
        scores.append(output.score.cpu().float().item())
    return scores


def _skywork_scores(messages_list, ranker, tokenizer):
    """Score each full conversation with the Skywork reward model (first logit)."""
    # https://huggingface.co/Skywork/Skywork-Reward-Gemma-2-27B
    scores = []
    for _, conversation in messages_list:
        text = tokenizer.apply_chat_template(conversation, tokenize=False)
        encoded = tokenizer(text, return_tensors="pt").to("cuda")
        with torch.no_grad():
            scores.append(ranker(**encoded).logits[0][0].item())
    return scores


def _prob_first_wins(logit_first, logit_second):
    """Softmax probability of the first of two logits.

    The max is subtracted before exponentiating so large logits cannot overflow.
    """
    top = max(logit_first, logit_second)
    exp_first = np.exp(logit_first - top)
    exp_second = np.exp(logit_second - top)
    return float(exp_first / (exp_first + exp_second))


def _pair_preference_scores(instruction, references_list, ranker, tokenizer):
    """Round-robin scoring with RLHFlow's pair-preference causal LM.

    *tokenizer* is the ``[chat_tokenizer, plain_tokenizer]`` pair built by main().
    For each pair (i, j) the model's "A"/"B" next-token logits give
    P(i preferred). It is averaged over both presentation orders, then added to
    i's score, and its complement is added to j's score.
    """
    tokenizer, tokenizer_plain = tokenizer
    prompt_template = "[CONTEXT] {context} [RESPONSE A] {response_A} [RESPONSE B] {response_B} \n"
    token_id_A = tokenizer.encode("A", add_special_tokens=False)
    token_id_B = tokenizer.encode("B", add_special_tokens=False)
    if len(token_id_A) != 1 or len(token_id_B) != 1:
        raise ValueError("'A' and 'B' must each encode to exactly one token")
    token_id_A = token_id_A[0]
    token_id_B = token_id_B[0]

    # Only a single user turn is used as context here.
    context = tokenizer_plain.apply_chat_template(
        [{"role": "user", "content": instruction}], tokenize=False
    )
    scores = [0.0] * len(references_list)
    for i in range(len(references_list)):
        for j in range(i + 1, len(references_list)):
            responses = [references_list[i][1], references_list[j][1]]
            probs_i_preferred = []
            for i_position in (0, 1):
                # Swap order to mitigate position bias.
                response_A = responses[i_position]
                response_B = responses[1 - i_position]
                text = prompt_template.format(context=context, response_A=response_A, response_B=response_B)
                message = [{"role": "user", "content": text}]
                input_ids = tokenizer.encode(
                    tokenizer.apply_chat_template(message, tokenize=False).replace(tokenizer.bos_token, ""),
                    return_tensors='pt',
                    add_special_tokens=False,
                ).cuda()
                with torch.no_grad():
                    output = ranker(input_ids)
                logit_A = output.logits[0, -1, token_id_A].item()
                logit_B = output.logits[0, -1, token_id_B].item()
                if i_position == 0:
                    probs_i_preferred.append(_prob_first_wins(logit_A, logit_B))
                else:
                    probs_i_preferred.append(_prob_first_wins(logit_B, logit_A))
            avg_prob = float(np.mean(probs_i_preferred))
            scores[i] += avg_prob
            scores[j] += 1 - avg_prob
    return scores


def process_fn(
    item,
    model,
    reference_models=None,
    temperature=0.7,
    max_tokens=1,
    prompt="",
    prompt_identifier="ABC",
    rounds=1,
    n=1,
    mode="Qwen_listwise",
    ranker=None,
    tokenizer=None,
):
    """Score one row's candidate conversations and pick a chosen/rejected pair.

    Args:
        item: dataset row. ``item['messages']`` maps a candidate key to that
            candidate's conversation (a list of ``{'role', 'content'}`` dicts
            whose last entry is the response). An optional ``item['instruction']``
            is used as the prompt by the LLM-judge and pair-preference modes.
            Otherwise the first message of the first candidate is used.
        model: judge model name. For the API modes, names containing "gpt" go to
            ``generate_openai`` and the rest go to ``generate_together``.
        reference_models, temperature, rounds: accepted for interface
            compatibility. They are unused here.
        max_tokens: generation budget for the ``Qwen_listwise`` judge.
        prompt: judge prompt template (Qwen_* modes).
        prompt_identifier: key into ``IDENTIFIERS`` for listwise labels.
        n: number of completions requested from the API judge.
        mode: one of "Qwen_listwise", "Qwen_pairwise", "pairRM_ranking",
            "pointwise_classifier", "Skywork-Reward",
            "pair-preference-model-LLaMA3-8B".
        ranker, tokenizer: local model objects prepared by ``main``.

    Returns:
        dict with ``reference_scores`` (one per candidate), ``chosen`` /
        ``reject`` (the selected conversations), and ``chosen_score`` /
        ``reject_score``. In "Qwen_listwise" the judge names only a winner, so
        ``reject`` and ``reject_score`` are None. In "pairRM_ranking" (rank-only)
        all scores are 0.

    Raises:
        ValueError: no candidates, or unknown ``mode``.
        NotImplementedError: ``mode == "pairRM_bestofn"``.
    """
    messages_list = list((item.get('messages') or {}).items())
    if not messages_list:
        raise ValueError("item has no candidate 'messages' to rank")

    reject_idx = None
    if mode == "Qwen_listwise":
        chosen_idx = _qwen_listwise_pick(
            _instruction_of(item, messages_list), _final_responses(messages_list),
            model, prompt, prompt_identifier, n, max_tokens,
        )
        # Listwise judging yields only a winner: mark it with 1.
        reference_scores = [0] * len(messages_list)
        reference_scores[chosen_idx] = 1
    elif mode == "pairRM_ranking":
        chosen_idx, reject_idx = _pairrm_rank(messages_list, ranker)
        # PairRM runs with return_scores=False, so there are no scores to report.
        reference_scores = [0] * len(messages_list)
    elif mode == "pairRM_bestofn":
        raise NotImplementedError("mode 'pairRM_bestofn' is not implemented")
    else:
        if mode == "Qwen_pairwise":
            reference_scores = _qwen_pairwise_scores(
                _instruction_of(item, messages_list), _final_responses(messages_list), model, prompt, n
            )
        elif mode == "pointwise_classifier":
            reference_scores = _armorm_scores(messages_list, ranker, tokenizer)
        elif mode == "Skywork-Reward":
            reference_scores = _skywork_scores(messages_list, ranker, tokenizer)
        elif mode == "pair-preference-model-LLaMA3-8B":
            reference_scores = _pair_preference_scores(
                _instruction_of(item, messages_list), _final_responses(messages_list), ranker, tokenizer
            )
        else:
            raise ValueError(f"Unknown mode: {mode!r}")
        logger.info(f"reference_scores: {reference_scores}")
        chosen_idx = reference_scores.index(max(reference_scores))
        reject_idx = reference_scores.index(min(reference_scores))
        logger.info(f"chosen_idx: {chosen_idx}, reject_idx: {reject_idx}")

    chosen = messages_list[chosen_idx][1]
    chosen_score = reference_scores[chosen_idx]
    if reject_idx is None:
        reject, reject_score = None, None
    else:
        reject = messages_list[reject_idx][1]
        reject_score = reference_scores[reject_idx]
    logger.debug(f"{reference_scores} {chosen} {reject} {chosen_score} {reject_score}")

    return {
        "reference_scores": reference_scores, "chosen": chosen, "reject": reject, "chosen_score": chosen_score, "reject_score": reject_score
    }


def _load_ranker(model):
    """Load the local ranker/reward model named by *model*.

    Returns ``(ranker, tokenizer)``. Both are None for API-judged models. For
    the pair-preference model, ``tokenizer`` is ``[chat_tokenizer, plain_tokenizer]``.
    """
    ranker = None
    tokenizer = None
    if model == "llm-blender/PairRM":
        import llm_blender
        ranker = llm_blender.Blender()
        ranker.loadranker("llm-blender/PairRM")
    elif model in ("RLHFlow/ArmoRM-Llama3-8B-v0.1", "Skywork/Skywork-Reward-Gemma-2-27B"):
        from transformers import AutoModelForSequenceClassification, AutoTokenizer
        ranker = AutoModelForSequenceClassification.from_pretrained(
            model,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
        )
        tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
    elif model == "RLHFlow/pair-preference-model-LLaMA3-8B":
        from transformers import AutoModelForCausalLM, AutoTokenizer
        ranker = AutoModelForCausalLM.from_pretrained(
            model, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
        ).cuda()
        chat_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", use_fast=True)
        tokenizer_plain = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", use_fast=True)
        # Plain "<turn> user/assistant" template used to render the [CONTEXT] block.
        tokenizer_plain.chat_template = "\n{% for message in messages %}{% if loop.index0 % 2 == 0 %}\n\n<turn> user\n {{ message['content'] }}{% else %}\n\n<turn> assistant\n {{ message['content'] }}{% endif %}{% endfor %}\n\n\n"
        tokenizer = [chat_tokenizer, tokenizer_plain]
    return ranker, tokenizer


def _resolve_reference_paths(reference_paths):
    """Turn the CLI value into a list of paths (glob pattern, comma list, or sequence)."""
    if reference_paths is None:
        return []
    # Fire may already have parsed "a,b" into a tuple/list.
    if isinstance(reference_paths, (list, tuple)):
        return [str(path) for path in reference_paths]
    if "*" in reference_paths:
        import glob
        return sorted(glob.glob(reference_paths))
    return reference_paths.split(',')


def _load_references(reference_paths):
    """Merge reference files row-by-row into ``{str(file_index): messages}`` dicts.

    Raises:
        ValueError: if the files do not all have the same number of rows. Rows
            are aligned by position, so a length mismatch would misalign them.
    """
    references = []
    for idx, reference_path in enumerate(reference_paths):
        with open(reference_path, encoding="utf-8") as f:
            reference_responses = json.load(f)
        logger.info(f"Reading reference outputs: {reference_path} ({len(reference_responses)})")
        if idx > 0 and len(reference_responses) != len(references):
            raise ValueError(
                f"{reference_path} has {len(reference_responses)} rows, "
                f"expected {len(references)} (same as {reference_paths[0]})"
            )
        for row, reference_response in enumerate(reference_responses):
            if idx == 0:
                references.append({str(idx): reference_response['messages']})
            else:
                references[row][str(idx)] = reference_response['messages']
    return references


def main(
    model: str,
    output_path: str,
    additional_info: str = "",
    reference_paths: str = None,
    promtp_path: str = None,
    prompt_identifier: str = "ABC",
    temperature: float = 0.7,
    max_tokens: int = 1,
    num_proc: int = 16,
    rounds=1,
    n=1,
    mode="Qwen_listwise"
):
    """Rank candidate responses per row with *model* and write the results as JSON.

    Args:
        model: judge / reward-model name. Some names trigger local model loading.
        output_path: root directory. Results go to ``<output_path>/<model_name>/``.
        additional_info: free-form suffix for the output file name.
        reference_paths: glob pattern (containing "*"), comma-separated list,
            or sequence of JSON files. Each file is a list of rows with a
            'messages' field. The first file also supplies the remaining columns.
        promtp_path: judge prompt template file. Required for "Qwen_*" modes.
        prompt_identifier, temperature, max_tokens, rounds, n, mode: forwarded
            to ``process_fn``. Some are also used in the output file name.
        num_proc: worker processes for ``Dataset.map``. It is forced to a single
            process when a local model has been loaded, because forked workers
            cannot reuse the parent's CUDA context.

    Raises:
        ValueError: no reference files, or a Qwen_* mode without a prompt.
    """
    paths = _resolve_reference_paths(reference_paths)
    if not paths:
        raise ValueError("at least one reference file is required (--reference_paths)")
    if promtp_path is None:
        if mode.startswith("Qwen_"):
            raise ValueError(f"mode {mode!r} requires --promtp_path")
        prompt, prompt_file = "", "none"
    else:
        with open(promtp_path, encoding="utf-8") as f:
            prompt = f.read()
        prompt_file = os.path.basename(promtp_path).split('.')[0]

    ranker, tokenizer = _load_ranker(model)

    with open(paths[0], encoding="utf-8") as f:
        base_rows = json.load(f)
    eval_set = datasets.Dataset.from_list(base_rows)
    # Candidate conversations are re-attached below as one merged 'messages' column.
    stale = [column for column in ("messages", "generator") if column in eval_set.column_names]
    if stale:
        eval_set = eval_set.remove_columns(stale)

    logger.info(f"`reference_paths` provided: {paths}")
    eval_set = eval_set.add_column("messages", _load_references(paths))

    logger.info("Start.")
    logger.info(eval_set)

    if ranker is not None and num_proc is not None and num_proc > 1:
        logger.warning(f"num_proc={num_proc} ignored: a local model is loaded; mapping in a single process")
        num_proc = None

    eval_set = eval_set.map(
        partial(
            process_fn,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            rounds=rounds,
            prompt=prompt,
            prompt_identifier=prompt_identifier,
            n=n,
            mode=mode,
            ranker=ranker,
            tokenizer=tokenizer,
        ),
        batched=False, num_proc=num_proc,
    )
    model_name = model.split('/')[-1]
    output_dir = os.path.join(output_path, model_name)
    os.makedirs(output_dir, exist_ok=True)

    # Drop bookkeeping columns that may come from the input files (each independently).
    leftover = [c for c in ("score_chosen", "score_rejected", "references") if c in eval_set.column_names]
    if leftover:
        eval_set = eval_set.remove_columns(leftover)
    eval_set_list = list(eval_set)

    output_file = os.path.join(
        output_dir,
        f"{model_name}-round_{rounds}-temp{temperature}-prompt{prompt_file}-{prompt_identifier}-{mode}-{additional_info}.json",
    )
    logger.info(f"Saving outputs to {output_file}.")
    with open(output_file, 'w', encoding="utf-8") as f:
        json.dump(eval_set_list, f, indent=2)


if __name__ == "__main__":
    # Expose `main` as a command-line interface (python-fire maps flags to its parameters).
    Fire(main)